from __future__ import print_function

import tensorflow as tf

tf = tf.compat.v1
tf.disable_eager_execution()
import cv2
import sys

sys.path.append("game/")
import game.wrapped_flappy_bird as game
import random
import numpy as np
from collections import deque

# ---------------------------------------------------------------------------
# Hyper-parameters
# ---------------------------------------------------------------------------
GAME = 'bird'  # game name; also used as the checkpoint file prefix
ACTIONS = 2  # number of actions; index 0 = do nothing (presumably 1 = flap/tap)
GAMMA = 0.99  # discount factor in the Q-learning target
OBSERVE = 1e4  # steps collected into replay memory before training starts
EXPLORE = 2e5  # steps over which epsilon is linearly annealed
FINAL_EPSILON = 1e-4  # lower bound for epsilon
# NOTE: equal to FINAL_EPSILON, so epsilon never decays (the annealing step
# only runs while epsilon > FINAL_EPSILON). Raise it to explore more.
INITIAL_EPSILON = 1e-4  # starting epsilon of the epsilon-greedy policy
REPLAY_MEMORY = 500000  # capacity of the replay memory
BATCH = 32  # minibatch size for each gradient step
FRAME_PER_ACTION = 1  # pick a new action every N frames


def weight_variable(shape):
    """Return a trainable weight tensor of the given *shape*.

    Initial values come from a truncated normal distribution with a
    standard deviation of 0.01.
    """
    return tf.Variable(tf.truncated_normal(shape, stddev=0.01))


def bias_variable(shape):
    """Return a trainable bias tensor of the given *shape*, filled with 0.01."""
    return tf.Variable(tf.constant(0.01, shape=shape))


def conv2d(x, W, stride):
    """Convolve *x* with filter *W* using the same *stride* on height and width.

    Uses SAME padding and TensorFlow's default channels-last layout, so the
    stride list is (batch, height, width, channels).
    """
    step = [1, stride, stride, 1]
    return tf.nn.conv2d(x, W, strides=step, padding="SAME")


def max_pool_2x2(x):
    """Max-pool *x* with a 2x2 window and stride 2 (SAME padding).

    Height and width shrink to ceil(size / 2).
    """
    window = [1, 2, 2, 1]
    return tf.nn.max_pool(x, ksize=window, strides=window, padding="SAME")


def createNetwork():
    """Build the Q-network graph.

    Architecture (input: four stacked 80x80 frames, channels last):
        conv 8x8 stride 4 (32) -> max-pool 2x2 -> conv 4x4 stride 2 (64)
        -> conv 3x3 stride 1 (64) -> fully connected 1600 -> 512 (ReLU)
        -> linear 512 -> ACTIONS

    Returns:
        (s, readout, h_fc1): the input placeholder of shape
        [None, 80, 80, 4], the Q-value tensor of shape [None, ACTIONS], and
        the 512-unit hidden activation.
    """
    # NOTE: variables are created in exactly the same order as before
    # (W_conv1, b_conv1, ..., W_fc2, b_fc2) so their auto-generated names
    # ("Variable", "Variable_1", ...) stay stable and old checkpoints restore.
    conv_specs = (([8, 8, 4, 32], 4), ([4, 4, 32, 64], 2), ([3, 3, 64, 64], 1))
    conv_params = []
    for kernel_shape, stride in conv_specs:
        kernel = weight_variable(kernel_shape)
        offset = bias_variable([kernel_shape[-1]])
        conv_params.append((kernel, offset, stride))

    hidden_w = weight_variable([1600, 512])
    hidden_b = bias_variable([512])
    out_w = weight_variable([512, ACTIONS])
    out_b = bias_variable([ACTIONS])

    # Input layer.
    s = tf.placeholder("float", [None, 80, 80, 4])

    # Hidden layers. Spatial size: 80 -> 20 (conv1) -> 10 (pool) -> 5 (conv2)
    # -> 5 (conv3), so the flattened feature map has 5 * 5 * 64 = 1600 values.
    features = s
    for layer_no, (kernel, offset, stride) in enumerate(conv_params):
        features = tf.nn.relu(conv2d(features, kernel, stride) + offset)
        if layer_no == 0:
            features = max_pool_2x2(features)  # pooling only after conv1
    flat = tf.reshape(features, [-1, 1600])
    h_fc1 = tf.nn.relu(tf.matmul(flat, hidden_w) + hidden_b)

    # Readout layer: one linear Q-value per action.
    readout = tf.matmul(h_fc1, out_w) + out_b

    return s, readout, h_fc1


def _preprocess_frame(frame):
    """Turn a raw BGR game frame into an 80x80 binary (0/255) grayscale image.

    The frame is resized to 80x80 and converted to grayscale. Every pixel
    brighter than 1 becomes 255 and every other pixel becomes 0.
    """
    gray = cv2.cvtColor(cv2.resize(frame, (80, 80)), cv2.COLOR_BGR2GRAY)
    _, binary = cv2.threshold(gray, 1, 255, cv2.THRESH_BINARY)
    return binary


def _phase_name(t):
    """Return the phase label ("observe"/"explore"/"train") logged at step *t*."""
    if t <= OBSERVE:
        return "observe"
    if t <= OBSERVE + EXPLORE:
        return "explore"
    return "train"


def trainNetwork(s, readout, h_fc1, sess):
    """Train the Q-network with DQN and experience replay; runs forever.

    Args:
        s: input placeholder returned by createNetwork().
        readout: Q-value tensor returned by createNetwork().
        h_fc1: hidden-layer tensor from createNetwork(); unused here, kept
            for interface compatibility.
        sess: TF1 session. ``readout.eval`` / ``train_step.run`` are called
            without an explicit session, so *sess* is expected to be the
            default session (e.g. an InteractiveSession).
    """
    import os

    # Loss: squared error between the target y and Q(s, a) of the taken action.
    a = tf.placeholder("float", [None, ACTIONS])
    y = tf.placeholder("float", [None])
    readout_action = tf.reduce_sum(tf.multiply(readout, a), axis=1)
    cost = tf.reduce_mean(tf.square(y - readout_action))
    train_step = tf.train.AdamOptimizer(1e-6).minimize(cost)

    # Open the game to talk to the emulator.
    game_state = game.GameState()

    # Replay memory of (s_t, a_t, r_t, s_t1, terminal) transitions; the oldest
    # entries are dropped automatically once REPLAY_MEMORY is reached.
    D = deque(maxlen=REPLAY_MEMORY)

    # Get the first state by doing nothing, then stack it into 80x80x4.
    do_nothing = np.zeros(ACTIONS)
    do_nothing[0] = 1
    x_t, _, _, _ = game_state.frame_step(do_nothing)
    x_t = _preprocess_frame(x_t)
    s_t = np.stack((x_t, x_t, x_t, x_t), axis=2)

    # Save / restore the network.
    save_dir = "saved_networks"
    checkpoint_prefix = 'saved_networks/' + GAME + '-dqn'
    save_every = 500000
    saver = tf.train.Saver()
    sess.run(tf.global_variables_initializer())
    checkpoint = tf.train.get_checkpoint_state(save_dir)
    if checkpoint and checkpoint.model_checkpoint_path:
        saver.restore(sess, checkpoint.model_checkpoint_path)
        print("成功加载:", checkpoint.model_checkpoint_path)
    else:
        print("找不到旧的网络权重")

    # Training loop.
    epsilon = INITIAL_EPSILON
    t = 0
    while True:
        # Epsilon-greedy action selection.
        readout_t = readout.eval(feed_dict={s: [s_t]})[0]
        a_t = np.zeros([ACTIONS])
        action_index = 0  # "do nothing" on frames where no new action is chosen
        if t % FRAME_PER_ACTION == 0:
            if random.random() <= epsilon:
                print("----------随机动作----------")
                # A single draw, so the logged index matches the action taken.
                action_index = random.randrange(ACTIONS)
            else:
                action_index = np.argmax(readout_t)
        a_t[action_index] = 1

        # Anneal epsilon linearly after the observation phase.
        if epsilon > FINAL_EPSILON and t > OBSERVE:
            epsilon -= (INITIAL_EPSILON - FINAL_EPSILON) / EXPLORE

        # Execute the action and observe the next frame, reward and terminal
        # flag (terminal = the bird hit something).
        x_t1_colored, r_t, terminal, score = game_state.frame_step(a_t)
        x_t1 = np.reshape(_preprocess_frame(x_t1_colored), (80, 80, 1))
        # Newest frame first, followed by the three most recent older frames.
        s_t1 = np.append(x_t1, s_t[:, :, :3], axis=2)

        # Store the transition in replay memory.
        D.append((s_t, a_t, r_t, s_t1, terminal))

        # Train only after observing, and only once a full batch is available.
        if t > OBSERVE and len(D) >= BATCH:
            minibatch = random.sample(D, BATCH)

            s_j_batch = [d[0] for d in minibatch]
            a_batch = [d[1] for d in minibatch]
            r_batch = np.array([d[2] for d in minibatch], dtype=np.float32)
            s_j1_batch = [d[3] for d in minibatch]
            terminal_batch = np.array([d[4] for d in minibatch], dtype=bool)

            # Bellman target: just the reward for terminal transitions,
            # otherwise reward + GAMMA * max_a' Q(s', a').
            readout_j1_batch = readout.eval(feed_dict={s: s_j1_batch})
            y_batch = np.where(
                terminal_batch,
                r_batch,
                r_batch + GAMMA * np.max(readout_j1_batch, axis=1),
            )

            # Gradient step.
            train_step.run(feed_dict={
                y: y_batch,
                a: a_batch,
                s: s_j_batch,
            })

        # Advance to the next state.
        s_t = s_t1
        t += 1

        # Periodically save progress. Saver.save fails if the parent
        # directory does not exist, so create it first.
        if t % save_every == 0:
            if not os.path.isdir(save_dir):
                os.makedirs(save_dir)
            saver.save(sess, checkpoint_prefix, global_step=t)

        # Log progress.
        print("迭代次数", t, "/ 状态", _phase_name(t),
              "/ EPSILON", epsilon, "/ ACTION", action_index, "/ REWARD", r_t,
              "/ Q_MAX %e" % np.max(readout_t), "分数：", score)


def playGame():
    """Open an interactive TF session, build the Q-network and train it.

    An InteractiveSession installs itself as the default session. trainNetwork
    needs that because it evaluates tensors via ``.eval()`` / ``.run()``.
    """
    session = tf.InteractiveSession()
    input_state, q_values, hidden = createNetwork()
    trainNetwork(input_state, q_values, hidden, sess=session)


def main():
    """Script entry point: build the network and start training."""
    return playGame()


if __name__ == "__main__":
    # Run only when executed as a script, not when imported.
    main()
